/* -LICENSE-START-
** Copyright (c) 2017 Blackmagic Design
**
** Permission is hereby granted, free of charge, to any person or organization
** obtaining a copy of the software and accompanying documentation covered by
** this license (the "Software") to use, reproduce, display, distribute,
** execute, and transmit the Software, and to prepare derivative works of the
** Software, and to permit third-parties to whom the Software is furnished to
** do so, all subject to the following:
**
** The copyright notices in the Software and this entire statement, including
** the above license grant, this restriction and the following disclaimer,
** must be included in all copies of the Software, in whole or in part, and
** all derivative works of the Software, unless such copies or derivative
** works are solely in the form of machine-executable object code generated by
** a source language processor.
**
** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
** FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
** SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
** FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
** ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
** DEALINGS IN THE SOFTWARE.
** -LICENSE-END-
*/
#include "VTDecodeSession.h"
#include <CoreMedia/CMFormatDescription.h>
#include <CoreMedia/CMBlockBuffer.h>
#include <cstdio>

// When `cond` is true, prints a printf-style message to stderr (prefixed with
// the file, line and function name) and jumps to the `bail:` label of the
// enclosing function.
//
// `fmt` must be a string literal because it is concatenated with the prefix
// format. The body is wrapped in do { } while (0) so the macro expands to
// exactly one statement: it can then be used as the unbraced arm of an
// if/else without breaking the else binding.
#define BAIL_IF(cond, fmt, args...)                                        \
do                                                                         \
{                                                                          \
	if (cond)                                                              \
	{                                                                      \
		fprintf(stderr, "[%s:%3d] %s: " fmt, __FILE__, __LINE__, __FUNCTION__, ##args); \
		goto bail;                                                         \
	}                                                                      \
} while (0)

// H.264 nal_unit_type values this file cares about. The type is carried in the
// low five bits of the NAL header byte (see isIDR / isSlice).
enum NalType
{
	kTypeSliceNotIDR = 1, // slice_layer_without_partitioning_rbsp (non-IDR picture)
	kTypeSliceA      = 2, // slice_data_partition_a_layer_rbsp
	kTypeSliceB      = 3, // slice_data_partition_b_layer_rbsp
	kTypeSliceC      = 4, // slice_data_partition_c_layer_rbsp
	kTypeSliceIDR    = 5, // slice_layer_without_partitioning_rbsp (IDR picture)
	kTypeSPS         = 7, // seq_parameter_set_rbsp
	kTypePPS         = 8  // pic_parameter_set_rbsp
};

// NAL data obtained "with size prefix" starts with a 4-byte size field; the
// NAL header byte, which carries the unit type, comes immediately after it.
static const int kNalPrefixSize = 4;
static const int kNalTypeOffset = kNalPrefixSize; // index of the header byte within size-prefixed NAL data

// H264 helpers : useful functions to extract the NAL unit type
static bool isIDR(uint8_t* nalDataWithSize)
{
	return ((NalType)(nalDataWithSize[kNalTypeOffset] & 0x1F) == kTypeSliceIDR);
}

static bool isSlice(uint8_t* nalDataWithSize)
{
	NalType type = (NalType)(nalDataWithSize[kNalTypeOffset] & 0x1F);
	return (type >= kTypeSliceNotIDR && type <= kTypeSliceIDR);
}

// Builds an idle decode session: no VideoToolbox session or format description
// exists yet, and an SPS, a PPS and an IDR slice are all still awaited.
VTDecodeSession::VTDecodeSession()
	: m_delegate()
	, m_session(NULL)
	, m_videoFormatDescription(NULL)
	, m_needsSPS(true)
	, m_needsPPS(true)
	, m_needsIDR(true)
	, m_NALsForOneFrameTotalSize(0)
	, m_firstSPS(NULL)
	, m_firstPPS(NULL)
	, m_nalParser(NULL)
{
	// Time scale passed to GetDisplayTime() when grouping NALs into frames.
	m_hostTimeScale = ::CVGetHostClockFrequency();

	m_nalParser = CreateBMDStreamingH264NALParser();

	// A missing parser is only reported here; handleH264NAL() refuses to run without one.
	if (! m_nalParser)
	{
		fprintf(stderr, "[%s:%3d] %s: Error: couldn't load NAL parser\n", __FILE__, __LINE__, __FUNCTION__);
	}
}

// Drops every reference the session still holds: the NAL parser, any stored
// SPS/PPS, the VideoToolbox session with its format description, and any NALs
// still queued for the current frame.
VTDecodeSession::~VTDecodeSession()
{
	if (m_nalParser != NULL)
	{
		m_nalParser->Release();
	}

	if (m_firstSPS != NULL)
	{
		m_firstSPS->Release();
	}

	if (m_firstPPS != NULL)
	{
		m_firstPPS->Release();
	}

	// The format description is released together with the session.
	if (m_session)
	{
		::VTDecompressionSessionInvalidate(m_session);
		::CFRelease(m_session);
		::CFRelease(m_videoFormatDescription);
	}

	clearNalsFrameList();
}

// Stores the delegate that receives decoded frames via haveVideoFrame().
// The pointer is only stored; no reference is taken on it here.
void VTDecodeSession::setDelegate(VTDecodeDelegate* delegate)
{
	this->m_delegate = delegate;
}

bool VTDecodeSession::startDecompressionSession()
{
	OSStatus 							status;
	CFMutableDictionaryRef				pixBufAttributes;
	CFMutableDictionaryRef				decoderAttributes;
	VTDecompressionOutputCallbackRecord	callbackRecord;
	long								format;
	CFNumberRef 						number;
	CMVideoDimensions  					dimensions;

	callbackRecord.decompressionOutputCallback = videoDecompressed;
	callbackRecord.decompressionOutputRefCon = this;
	m_session = nil;

	uint8_t* parameterSetPointers[2];
	size_t parameterSetSizes[2];

	parameterSetSizes[0] = m_firstSPS->GetPayloadSize();
	m_firstSPS->GetBytes((void**)&parameterSetPointers[0]);

	parameterSetSizes[1] = m_firstPPS->GetPayloadSize();
	m_firstPPS->GetBytes((void**)&parameterSetPointers[1]);


	status = ::CMVideoFormatDescriptionCreateFromH264ParameterSets(NULL, 2, parameterSetPointers, parameterSetSizes, 4, &m_videoFormatDescription);
	BAIL_IF(status != noErr, "ERROR: CMVideoFormatDescriptionCreateFromH264ParameterSets %x\n", status);

	dimensions = ::CMVideoFormatDescriptionGetDimensions(m_videoFormatDescription);

	pixBufAttributes = ::CFDictionaryCreateMutable(NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks);

	number = ::CFNumberCreate(NULL, kCFNumberShortType, &dimensions.width);
	::CFDictionaryAddValue(pixBufAttributes, kCVPixelBufferWidthKey, number);
	::CFRelease(number);

	number = ::CFNumberCreate(NULL, kCFNumberShortType, &dimensions.height);
	::CFDictionaryAddValue(pixBufAttributes, kCVPixelBufferHeightKey, number);
	::CFRelease(number);

	format = k2vuyPixelFormat;
	number = ::CFNumberCreate(NULL, kCFNumberSInt32Type, &format);
	::CFDictionaryAddValue(pixBufAttributes, kCVPixelBufferPixelFormatTypeKey, number);
	::CFRelease(number);

	decoderAttributes = ::CFDictionaryCreateMutable(NULL, 0, &kCFTypeDictionaryKeyCallBacks, &kCFTypeDictionaryValueCallBacks);

	::CFDictionarySetValue(decoderAttributes, kVTVideoDecoderSpecification_EnableHardwareAcceleratedVideoDecoder, kCFBooleanTrue);

	status = ::VTDecompressionSessionCreate(NULL, m_videoFormatDescription, decoderAttributes, pixBufAttributes, &callbackRecord, &m_session);
	::CFRelease(pixBufAttributes);
	::CFRelease(decoderAttributes);
	BAIL_IF(status != noErr, "VTDecompressionSessionCreate returned %i, session %p\n", status, m_session);

	status = ::VTSessionSetProperty(m_session, kVTDecompressionPropertyKey_FieldMode, kVTDecompressionProperty_FieldMode_DeinterlaceFields);
	BAIL_IF(status != noErr, "VTSessionSetProperty returned %i, session %p\n", status, m_session);

	return true;

bail:
	if (m_videoFormatDescription != NULL)
	{
		::CFRelease(m_videoFormatDescription);
		m_videoFormatDescription = NULL;
	}

	if (m_session != NULL)
	{
		::VTDecompressionSessionInvalidate(m_session);
		::CFRelease(m_session);
		m_session = NULL;
	}

	return false;
}

// Feeds one H.264 NAL packet into the decoder.
//
// Until a format description exists, NALs are only inspected for the first SPS
// and PPS; once both are stored the decompression session is started. After
// that, NALs are skipped until the first IDR slice. Slice NALs are then queued
// and grouped by display time: when a NAL arrives with a different display
// time from the queued ones, the queued group is decoded as one frame.
//
// nal: the packet to process. A reference is taken on it only if it is stored
//      (as first SPS/PPS) or queued (as a slice).
void VTDecodeSession::handleH264NAL(IBMDStreamingH264NALPacket* nal)
{
	static const int kMaxNalsInFrameList = 256; // arbitrary
	uint8_t* nalData = NULL;
	HRESULT result;
	uint64_t displayTime;
	uint64_t lastDisplayTime;

	BAIL_IF (! m_nalParser, "Error: NAL parser not loaded\n");
	BAIL_IF (! nal, "Error: nal packet is not valid\n");

	if (! m_videoFormatDescription)
	{
		// Skip all NALs until we have both SPS and PPS, then we can set up the
		// decompression session.
		if (m_needsSPS && m_nalParser->IsNALSequenceParameterSet(nal) == S_OK)
		{
			m_firstSPS = nal;
			m_firstSPS->AddRef();
			m_needsSPS = false;
		}

		if (m_needsPPS && m_nalParser->IsNALPictureParameterSet(nal) == S_OK)
		{
			m_firstPPS = nal;
			m_firstPPS->AddRef();
			m_needsPPS = false;
		}

		if (! m_firstSPS || ! m_firstPPS)
			goto bail;

		if (! startDecompressionSession())
			goto bail;

		// The format description now carries the parameter sets, so the
		// packets themselves are no longer needed.
		m_firstSPS->Release();
		m_firstPPS->Release();
		m_firstSPS = NULL;
		m_firstPPS = NULL;
		m_needsIDR = true;
	}

	m_nalDecompressed = false;
	m_pixBuf = NULL;

	// isIDR()/isSlice() read the header byte that follows the size prefix, so
	// make sure the data pointer is valid and there is a payload to inspect.
	result = nal->GetBytesWithSizePrefix((void**)&nalData);
	BAIL_IF(FAILED(result) || ! nalData, "Error: No data for nal (%x)\n", result);
	if (nal->GetPayloadSize() < 1)
		goto bail;

	// No point trying to decompress until the first IDR
	if (m_needsIDR)
	{
		if (! isIDR(nalData))
			goto bail;

		clearNalsFrameList();
		m_needsIDR = false;
	}

	// VideoToolbox expects a group of NALs representing one frame.
	// Group the NALs based on their display time.
	result = nal->GetDisplayTime(m_hostTimeScale, &displayTime);
	BAIL_IF(FAILED(result), "Error: No display time for nal (%x)\n", result);

	if (! m_NALsForOneFrame.empty())
	{
		result = m_NALsForOneFrame.back()->GetDisplayTime(m_hostTimeScale, &lastDisplayTime);
		BAIL_IF(FAILED(result), "Error: No display time for nal (%x)\n", result);

		// A new display time means the queued NALs form a complete frame:
		// decode them (this also empties the queue).
		if (displayTime != lastDisplayTime)
			decompressFrame();
	}

	// Queue the current NAL if it is a slice. This must also happen right
	// after a flush, otherwise the first slice of the new frame would be lost.
	if (isSlice(nalData))
	{
		nal->AddRef();
		m_NALsForOneFrame.push_back(nal);
		m_NALsForOneFrameTotalSize += nal->GetPayloadSize() + kNalPrefixSize;

		// For safety: if we're seeing too many slices but no new frame, start again from the next IDR
		if (m_NALsForOneFrame.size() >= (size_t)kMaxNalsInFrameList)
		{
			m_needsIDR = true;
		}
	}

bail:
	return;
}

void VTDecodeSession::decompressFrame()
{
	CMBlockBufferRef  blockBuffer = NULL;
	CMSampleBufferRef sampleBuffer = NULL;
	OSStatus status;
	uint8_t* nalData = NULL;
	size_t nalSize;
	size_t pos = 0;

	status = ::CMBlockBufferCreateWithMemoryBlock(kCFAllocatorDefault, NULL, m_NALsForOneFrameTotalSize, NULL, NULL, 0, m_NALsForOneFrameTotalSize, kCMBlockBufferAssureMemoryNowFlag, &blockBuffer);
	BAIL_IF(status != kCMBlockBufferNoErr, "CMBlockBufferCreateWithMemoryBlock failed (%i)\n", status);

	// Group all the NAL slices belonging to the same frame into one single SampleBuffer to decompress
	for (NalsFrameList::iterator it = m_NALsForOneFrame.begin(); it != m_NALsForOneFrame.end(); ++it)
	{
		IBMDStreamingH264NALPacket* nal = *it;

		nal->GetBytesWithSizePrefix((void**)&nalData);
		nalSize = nal->GetPayloadSize() + kNalPrefixSize;

		status = ::CMBlockBufferReplaceDataBytes(nalData, blockBuffer, pos, nalSize);
		BAIL_IF(status != kCMBlockBufferNoErr, "CMBlockBufferReplaceDataBytes failed (%i)\n", status);
		pos += nalSize;
	}
	status = ::CMSampleBufferCreate(NULL, blockBuffer, true, NULL, NULL, m_videoFormatDescription, 1, 0, NULL, 1, &m_NALsForOneFrameTotalSize, &sampleBuffer);
	BAIL_IF(status, "CMSampleBufferCreate failed (%i)\n", status);

	// The flag for asynchronous decompression is explicitly not passed here to ensure that
	// VTDecompressionSessionDecodeFrame blocks until the callback is called, allowing this application
	// to process frames serially.
	status = ::VTDecompressionSessionDecodeFrame(m_session, sampleBuffer, 0, NULL, NULL);
	BAIL_IF(status, "VTDecompressionSessionDecodeFrame failed (%i)\n", status);
	BAIL_IF(! m_nalDecompressed, "Unexpected error: nal decompressed should be true here\n");

	if (m_pixBuf)
	{
		m_delegate->haveVideoFrame(m_pixBuf, m_NALsForOneFrame.back());

		::CVPixelBufferRelease(m_pixBuf);
		m_pixBuf = NULL;
	}

bail:

	if (sampleBuffer)
		::CFRelease(sampleBuffer);

	if (blockBuffer)
		::CFRelease(blockBuffer);

	clearNalsFrameList();
}

void VTDecodeSession::clearNalsFrameList()
{
	m_NALsForOneFrameTotalSize = 0;
	while (! m_NALsForOneFrame.empty())
	{
		m_NALsForOneFrame.front()->Release();
		m_NALsForOneFrame.pop_front();
	}
}

// Records the outcome of one decode: on success the pixel buffer is retained
// and stored in m_pixBuf, otherwise m_pixBuf is cleared. In both cases
// m_nalDecompressed is set so decompressFrame() knows the callback ran.
//
// result: status reported by the decoder for this frame.
// flags:  decode info flags (unused).
// pixBuf: decoded pixel buffer; only used when result is noErr.
void VTDecodeSession::decompressionCompleteWithResult(OSStatus result, VTDecodeInfoFlags flags, CVPixelBufferRef pixBuf)
{
	const bool decodedOk = (result == noErr);

	m_pixBuf = decodedOk ? pixBuf : NULL;

	// Keep the buffer alive beyond the callback; decompressFrame() releases it.
	if (decodedOk)
	{
		::CVPixelBufferRetain(pixBuf);
	}

	m_nalDecompressed = true;
}

void VTDecodeSession::videoDecompressed(void* vtDecodeSession,
										void* sourceFrameRefCon,
										OSStatus result,
										VTDecodeInfoFlags flags,
										CVImageBufferRef pixBuf,
										CMTime time,
										CMTime duration)
{
	VTDecodeSession* realSelf = (VTDecodeSession*)vtDecodeSession;
	realSelf->decompressionCompleteWithResult(result, flags, pixBuf);
}
